# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from ppcls.arch import build_model
from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.data import create_operators
from utils import get_cv_image
import sys

# NOTE(review): this runs after the imports above, so it does not influence
# how `from utils import get_cv_image` is resolved; presumably it is intended
# for imports performed later (or was meant to precede them) — confirm.
sys.path.append("..")


class FoodParser(object):
    """Extract feature vectors from images with a PaddleClas model.

    The device, model architecture, pretrained weights and preprocessing
    pipeline are all taken from a PaddleClas-style config dict.
    """

    def __init__(self, config, mode=None):
        """Build the model, load its weights and create the preprocessing ops.

        Args:
            config: Config dict. Reads ``config["Global"]["device"]``,
                ``config["Arch"]``, ``config["Global"]["pretrained_model"]``
                and ``config["Infer_Memery"]["transforms"]``.
            mode: Unused; kept for interface compatibility.
        """
        # Initialise configuration and target device.
        self.config = config
        self.device = paddle.set_device(self.config["Global"]["device"])
        self.model = build_model(self.config["Arch"])
        # Load the pretrained weights into the freshly built model.
        load_dygraph_pretrain(self.model, self.config["Global"]["pretrained_model"])
        self.preprocess_func = create_operators(self.config["Infer_Memery"]["transforms"])

    def _preprocess(self, img):
        """Run ``img`` through every configured preprocessing operator in order."""
        for process in self.preprocess_func:
            img = process(img)
        return img

    @staticmethod
    def _crop(image, box):
        """Return the ``image[y1:y2, x1:x2]`` crop for ``box = (x1, y1, x2, y2)``.

        Coordinates are cast to ``int`` so float boxes (e.g. straight from a
        detector) can be used as slice bounds, and clamped at 0 because a
        negative start index would otherwise wrap around to the end of the
        axis and produce a wrong or empty crop. The upper bound needs no
        clamping: slicing past the end of an axis is truncated.
        """
        x1, y1, x2, y2 = (max(0, int(v)) for v in box)
        return image[y1:y2, x1:x2]

    @paddle.no_grad()
    def get_feature(self, img_data, row, col, chan, format):
        """Return the feature vector of a single (small) image.

        Args:
            img_data, row, col, format: Forwarded to ``get_cv_image`` to build
                the image (presumably raw pixel data, height, width and pixel
                format — TODO confirm against ``utils.get_cv_image``).
            chan: Unused; kept for interface compatibility.

        Returns:
            Row 0 of the model's ``'features'`` output as a numpy array. The
            original comment suggests the batch output is ``[1, 512]``; the
            actual width depends on the configured architecture.
        """
        image = get_cv_image(img_data, row, col, format)
        self.model.eval()
        # The model expects a batch, so wrap the single sample in a list.
        batch_tensor = paddle.to_tensor([self._preprocess(image)])
        out = self.model(batch_tensor)
        return out['features'].numpy()[0]

    @paddle.no_grad()
    def get_features(self, img_data, row, col, chan, format, bboxes):
        """Return one feature vector per bounding box of a (large) image.

        Args:
            img_data, row, col, format: Forwarded to ``get_cv_image`` to build
                the full image (see ``get_feature``).
            chan: Unused; kept for interface compatibility.
            bboxes: Iterable of 4-element ``(x1, y1, x2, y2)`` boxes in pixel
                coordinates; each element is cast to ``int`` and clamped at 0.

        Returns:
            ``(features, bboxes)``: the model's ``'features'`` output for the
            batch of crops as a numpy array (one row per box, in ``bboxes``
            order), and the ``bboxes`` argument itself, unchanged.

        NOTE(review): an empty ``bboxes`` produces an empty batch tensor and
        the model call will presumably fail — confirm whether callers need an
        empty result instead.
        """
        image = get_cv_image(img_data, row, col, format)
        self.model.eval()
        batch_data = [self._preprocess(self._crop(image, box)) for box in bboxes]
        out = self.model(paddle.to_tensor(batch_data))
        return out['features'].numpy(), bboxes
